{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I17bSxxg1evl"
   },
   "source": [
    "# Signature-Aware Model Serving from MLflow with Ray Serve\n",
    "\n",
    "_Authored by: [Jonathan Jin](https://huggingface.co/jinnovation)_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IuS0daXP1lIa"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "This notebook explores solutions for streamlining the deployment of models from a model registry. For teams that want to productionize many models over time, investments at this \"transition point\" in the AI/ML project lifecycle can meaningfully drive down time-to-production. This can be important for a younger, smaller team that may not have the benefit of existing infrastructure to form a \"golden path\" for serving online models in production.\n",
    "\n",
    "## Motivation\n",
    "\n",
    "Optimizing this stage of the model lifecycle is particularly important due to the production-facing aspect of the end result. At this stage, your model becomes, in effect, a microservice. This means that you now need to contend with all elements of service ownership, which can include:\n",
    "\n",
    "- Standardizing and enforcing API backwards-compatibility;\n",
    "- Logging, metrics, and general observability concerns;\n",
    "- Etc.\n",
    "\n",
    "Needing to repeat the same general-purpose setup each time you want to deploy a new model will result in development costs adding up significantly over time for you and your team. On the flip side, given the \"long tail\" of production-model ownership (assuming a productionized model is not likely to be decommissioned anytime soon), streamlining investments here can pay healthy dividends over time.\n",
    "\n",
    "Given all of the above, we motivate our exploration here with the following user story:\n",
    "\n",
    "> I would like to deploy a model from a model registry (such as [MLflow](https://mlflow.org/)) using **only the name of the model**. The less boilerplate and scaffolding that I need to replicate each time I want to deploy a new model, the better. I would like the ability to dynamically select between different versions of the model without needing to set up a whole new deployment to accommodate those new versions.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fXlB7AJr2foY"
   },
   "source": [
    "## Components\n",
    "\n",
    "For our exploration here, we'll use the following minimal stack:\n",
    "\n",
    "- MLflow for model registry;\n",
    "- Ray Serve for model serving.\n",
    "\n",
    "For demonstrative purposes, we'll exclusively use off-the-shelf open-source models from Hugging Face Hub.\n",
    "\n",
    "We will **not** use GPUs for inference because inference performance is orthogonal to our focus here today. Needless to say, in \"real life,\" you will likely not be able to get away with serving your model with CPU compute.\n",
    "\n",
    "Let's install our dependencies now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "HfLQGO6E2hnW",
    "outputId": "c9634e63-5aaf-4e59-e970-aecb36d25b77"
   },
   "outputs": [],
   "source": [
    "!pip install \"transformers\" \"mlflow-skinny\" \"ray[serve]\" \"torch\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "C0UziXBN4Szc"
   },
   "source": [
    "## Register the Model\n",
    "\n",
    "First, let's define the model that we'll use for our exploration today. For simplicity's sake, we'll use a simple text translation model, where the source and destination languages are configurable at registration time. In effect, this means that different \"versions\" of the model can be registered to translate different languages, but the underlying model architecture and weights can stay the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D2HsBFUa4nBM"
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "from transformers import pipeline\n",
    "\n",
    "class MyTranslationModel(mlflow.pyfunc.PythonModel):\n",
    "    def load_context(self, context):\n",
    "        self.lang_from = context.model_config.get(\"lang_from\", \"en\")\n",
    "        self.lang_to = context.model_config.get(\"lang_to\", \"de\")\n",
    "\n",
    "        self.input_label: str = context.model_config.get(\"input_label\", \"prompt\")\n",
    "\n",
    "        self.model_ref: str = context.model_config.get(\"hfhub_name\", \"google-t5/t5-base\")\n",
    "\n",
    "        self.pipeline = pipeline(\n",
    "            f\"translation_{self.lang_from}_to_{self.lang_to}\",\n",
    "            self.model_ref,\n",
    "        )\n",
    "\n",
    "    def predict(self, context, model_input, params=None):\n",
    "        prompt = model_input[self.input_label].tolist()\n",
    "\n",
    "        return self.pipeline(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-PFbVlpdIBHA"
   },
   "source": [
    "(You might be wondering why we even bothered making the input label configurable. This will be useful to us later.)\n",
    "\n",
    "Now that our model is defined, let's register an actual version of it. This particular version will use Google's [T5 Base](https://huggingface.co/google-t5/t5-base) model and be configured to translate from **English** to **German**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "SpGCrnAx6eVf",
    "outputId": "11218a74-11fa-471b-cc86-03a150b64f20"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "with mlflow.start_run():\n",
    "    model_info = mlflow.pyfunc.log_model(\n",
    "        \"translation_model\",\n",
    "        registered_model_name=\"translation_model\",\n",
    "        python_model=MyTranslationModel(),\n",
    "        pip_requirements=[\"transformers\"],\n",
    "        input_example=pd.DataFrame({\n",
    "            \"prompt\": [\"Hello my name is Jonathan.\"],\n",
    "        }),\n",
    "        model_config={\n",
    "            \"hfhub_name\": \"google-t5/t5-base\",\n",
    "            \"lang_from\": \"en\",\n",
    "            \"lang_to\": \"de\",\n",
    "        },\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NaUwo6E0DPbI"
   },
   "source": [
    "Let's keep track of this exact version. This will be useful later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "e0o4ICh38Pjy"
   },
   "outputs": [],
   "source": [
    "en_to_de_version: str = str(model_info.registered_model_version)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Jn0RU7fXDTdD"
   },
   "source": [
    "The registered model metadata contains some useful information for us. Most notably, the registered model version is associated with a strict **signature** that denotes the expected shape of its input and output. This will be useful to us later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ZKMgYR_jDhOA",
    "outputId": "7f1410df-cde3-4160-eee8-30788a402b3b",
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs: \n",
      "  ['prompt': string (required)]\n",
      "outputs: \n",
      "  ['translation_text': string (required)]\n",
      "params: \n",
      "  None\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(model_info.signature)"
   ]
  },
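  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the notion of a signature a little more concrete, here is a minimal sketch that builds an equivalent signature by hand and compares it against a variant whose input column is named differently. It assumes only MLflow's `ModelSignature`, `Schema`, and `ColSpec` types; the `text_to_translate` label is purely illustrative.\n",
    "\n",
    "```python\n",
    "from mlflow.models import ModelSignature\n",
    "from mlflow.types import ColSpec, Schema\n",
    "\n",
    "# Hand-built equivalent of the signature printed above.\n",
    "signature = ModelSignature(\n",
    "    inputs=Schema([ColSpec(\"string\", \"prompt\")]),\n",
    "    outputs=Schema([ColSpec(\"string\", \"translation_text\")]),\n",
    ")\n",
    "\n",
    "# Same output schema, but a differently named input column (illustrative only).\n",
    "renamed = ModelSignature(\n",
    "    inputs=Schema([ColSpec(\"string\", \"text_to_translate\")]),\n",
    "    outputs=Schema([ColSpec(\"string\", \"translation_text\")]),\n",
    ")\n",
    "\n",
    "# List the input column names, then compare the two signatures.\n",
    "print(signature.inputs.input_names())\n",
    "print(signature == renamed)\n",
    "```\n",
    "\n",
    "Two signatures should compare equal only when their schemas match, so renaming a single input column is enough to make them differ."
   ]
  },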
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iwa3o-0B9FPO"
   },
   "source": [
    "## Serve the Model\n",
    "\n",
    "Now that our model is registered in MLflow, let's set up our serving scaffolding using [Ray Serve](https://docs.ray.io/en/latest/serve/index.html). For now, we'll limit our \"deployment\" to the following behavior:\n",
    "\n",
    "- Source the seleted model and version from MLflow;\n",
    "- Receive inference requests and return inference responses via a simple REST API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7OZ2lqOS9oqw"
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "import pandas as pd\n",
    "\n",
    "from ray import serve\n",
    "from fastapi import FastAPI\n",
    "\n",
    "app = FastAPI()\n",
    "\n",
    "@serve.deployment\n",
    "@serve.ingress(app)\n",
    "class ModelDeployment:\n",
    "    def __init__(self, model_name: str = \"translation_model\", default_version: str = \"1\"):\n",
    "        self.model_name = model_name\n",
    "        self.default_version = default_version\n",
    "\n",
    "        self.model = mlflow.pyfunc.load_model(f\"models:/{self.model_name}/{self.default_version}\")\n",
    "\n",
    "\n",
    "    @app.post(\"/serve\")\n",
    "    async def serve(self, input_string: str):\n",
    "        return self.model.predict(pd.DataFrame({\"prompt\": [input_string]}))\n",
    "\n",
    "deployment = ModelDeployment.bind(default_version=en_to_de_version)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f018wd2fEia7"
   },
   "source": [
    "You might have notice that hard-coding `\"prompt\"` as the input label here introduces hidden coupling between the registered model's signature and the deployment implementation. We'll come back to this later.\n",
    "\n",
    "Now, let's run the deployment and play around with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "MudMnivd_DrC",
    "outputId": "7f23394f-9f3e-4ce1-c67a-82c59a5bc25f"
   },
   "outputs": [],
   "source": [
    "serve.run(deployment, blocking=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "VTk1E5pp_gRz",
    "outputId": "67a20366-f637-4a0a-8c51-0f71bf5e1ea6",
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'translation_text': 'Das Wetter ist heute nett.'}]\n"
     ]
    }
   ],
   "source": [
    "import requests\n",
    "\n",
    "response = requests.post(\n",
    "    \"http://127.0.0.1:8000/serve/\",\n",
    "    params={\"input_string\": \"The weather is lovely today\"},\n",
    ")\n",
    "\n",
    "print(response.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i3CNI-mmE_22"
   },
   "source": [
    "This works fine, but you might have noticed that the REST API does not line up with the model signature. Namely, it uses the label `\"input_string\"` while the served model version itself uses the input label `\"prompt\"`. Similarly, the model can accept multiple inputs values, but the API only accepts one.\n",
    "\n",
    "If this feels [smelly](https://en.wikipedia.org/wiki/Code_smell) to you, keep reading; we'll come back to this."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hsJ65rNNDMVj"
   },
   "source": [
    "## Multiple Versions, One Endpoint\n",
    "\n",
    "Now we've got a basic endpoint set up for our model. Great! However, notice that this deployment is strictly tethered to a single version of this model -- specifically, version `1` of the registered `translation_model`.\n",
    "\n",
    "Imagine, now, that your team would like to come back and refine this model -- maybe retrain it on new data, or configure it to translate to a new language, e.g. French instead of German. Both would result in a new version of the `translation_model` getting registered. However, with our current deployment implementation, we'd need to set up a whole new endpoint for `translation_model/2`, require our users to remember which address and port corresponds to which version of the model, and so on. In other words: very cumbersome, very error-prone, very [toilsome](https://leaddev.com/velocity/what-toil-and-why-it-damaging-your-engineering-org).\n",
    "\n",
    "Conversely, imagine a scenario where we could reuse the exact same endpoint -- same signature, same address and port, same query conventions, etc. -- to serve both versions of this model. Our user can simply specify which version of the model they'd like to use, and we can treat one of them as the \"default\" in cases where the user didn't explicitly request one.\n",
    "\n",
    "This is one area where Ray Serve shines with a feature it calls [model multiplexing](https://docs.ray.io/en/latest/serve/model-multiplexing.html). In effect, this allows you to load up multiple \"versions\" of your model, dynamically hot-swapping them as needed, as well as unloading the versions that don't get used after some time. Very space-efficient, in other words.\n",
    "\n",
    "Let's try registering another version of the model -- this time, one that translates from English to French. We'll register this under the version `\"2\"`; the model server will retrieve the model version that way.\n",
    "\n",
    "But first, let's extend the model server with multiplexing support."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "d8GcI3WLE3Sc"
   },
   "outputs": [],
   "source": [
    "from ray import serve\n",
    "from fastapi import FastAPI\n",
    "\n",
    "app = FastAPI()\n",
    "\n",
    "@serve.deployment\n",
    "@serve.ingress(app)\n",
    "class MultiplexedModelDeployment:\n",
    "\n",
    "    @serve.multiplexed(max_num_models_per_replica=2)\n",
    "    async def get_model(self, version: str):\n",
    "        return mlflow.pyfunc.load_model(f\"models:/{self.model_name}/{version}\")\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_name: str = \"translation_model\",\n",
    "        default_version: str = en_to_de_version,\n",
    "    ):\n",
    "        self.model_name = model_name\n",
    "        self.default_version = default_version\n",
    "\n",
    "    @app.post(\"/serve\")\n",
    "    async def serve(self, input_string: str):\n",
    "        model = await self.get_model(serve.get_multiplexed_model_id())\n",
    "        return model.predict(pd.DataFrame({\"prompt\": [input_string]}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "f-gisRU_FKlJ",
    "outputId": "a0c7318d-8271-4163-d58d-9ed97df72266"
   },
   "outputs": [],
   "source": [
    "multiplexed_deployment = MultiplexedModelDeployment.bind(model_name=\"translation_model\")\n",
    "serve.run(multiplexed_deployment, blocking=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Qs7snXhxdlUR"
   },
   "source": [
    "Now let's actually register the new model version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "K3_essFBEuCo",
    "outputId": "b7f4f9e7-62bf-40ae-ed8a-db0110ad2e4f"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "with mlflow.start_run():\n",
    "    model_info = mlflow.pyfunc.log_model(\n",
    "        \"translation_model\",\n",
    "        registered_model_name=\"translation_model\",\n",
    "        python_model=MyTranslationModel(),\n",
    "        pip_requirements=[\"transformers\"],\n",
    "        input_example=pd.DataFrame({\n",
    "            \"prompt\": [\n",
    "                \"Hello my name is Jon.\",\n",
    "            ],\n",
    "        }),\n",
    "        model_config={\n",
    "            \"hfhub_name\": \"google-t5/t5-base\",\n",
    "            \"lang_from\": \"en\",\n",
    "            \"lang_to\": \"fr\",\n",
    "        },\n",
    "    )\n",
    "\n",
    "en_to_fr_version: str = str(model_info.registered_model_version)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rxOzkg65dnZW"
   },
   "source": [
    "Now that that's registered, we can query for it via the model server like so..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EyeLmnPJFuRH",
    "outputId": "9dfb8df0-f207-42ae-b78b-db51d8843c15",
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'translation_text': \"Le temps est beau aujourd'hui\"}]\n"
     ]
    }
   ],
   "source": [
    "import requests\n",
    "\n",
    "response = requests.post(\n",
    "    \"http://127.0.0.1:8000/serve/\",\n",
    "    params={\"input_string\": \"The weather is lovely today\"},\n",
    "    headers={\"serve_multiplexed_model_id\": en_to_fr_version},\n",
    ")\n",
    "\n",
    "print(response.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jVMCS4CedudN"
   },
   "source": [
    "Note how we were able to immediately access the model version **without redeploying the model server**. Ray Serve's multiplexing capabilities allow it to dynamically fetch the model weights in a just-in-time fashion; if I never requested version 2, it never gets loaded. This helps conserve compute resources for the models that **do** get queried. What's even more useful is that, if the number of models loaded up exceeds the configured maximum (`max_num_models_per_replica`), the [least-recently used model version will get evicted](https://docs.ray.io/en/latest/serve/model-multiplexing.html#why-model-multiplexing).\n",
    "\n",
    "Given that we set `max_num_models_per_replica=2` above, the \"default\" English-to-German version of the model should still be loaded up and readily available to serve requests without any cold-start time. Let's confirm that now:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "jEJFQNlwGGKh",
    "outputId": "b847d92e-fe0f-4439-bd87-e5773680c4d1",
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'translation_text': 'Das Wetter ist heute nett.'}]\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    requests.post(\n",
    "        \"http://127.0.0.1:8000/serve/\",\n",
    "        params={\"input_string\": \"The weather is lovely today\"},\n",
    "        headers={\"serve_multiplexed_model_id\": en_to_de_version},\n",
    "    ).json()\n",
    ")"
   ]
  },
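  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The least-recently-used eviction policy described above can be sketched in a few lines of plain Python. This is only a toy illustration of the policy, not Ray Serve's actual implementation; the `LRUModelCache` class and its `loader` callable are hypothetical names introduced for this example.\n",
    "\n",
    "```python\n",
    "from collections import OrderedDict\n",
    "\n",
    "\n",
    "class LRUModelCache:\n",
    "    # Toy stand-in for a per-replica model cache (hypothetical, for illustration).\n",
    "\n",
    "    def __init__(self, max_models, loader):\n",
    "        self.max_models = max_models\n",
    "        self.loader = loader  # callable mapping a version string to a loaded model\n",
    "        self.models = OrderedDict()\n",
    "\n",
    "    def get(self, version):\n",
    "        if version in self.models:\n",
    "            # Cache hit: mark this version as the most recently used.\n",
    "            self.models.move_to_end(version)\n",
    "            return self.models[version]\n",
    "        # Cache miss: load the model (the \"cold start\" cost is paid here).\n",
    "        model = self.loader(version)\n",
    "        self.models[version] = model\n",
    "        if len(self.models) > self.max_models:\n",
    "            # Over capacity: evict the least recently used version.\n",
    "            self.models.popitem(last=False)\n",
    "        return model\n",
    "\n",
    "\n",
    "cache = LRUModelCache(max_models=2, loader=lambda v: f\"model-v{v}\")\n",
    "for version in [\"1\", \"2\", \"1\", \"3\"]:\n",
    "    cache.get(version)\n",
    "\n",
    "print(list(cache.models))  # ['1', '3']\n",
    "```\n",
    "\n",
    "Version `\"2\"` is the one evicted here because version `\"1\"` was touched again before version `\"3\"` arrived."
   ]
  },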
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "D8CgPXcsIg5C"
   },
   "source": [
    "## Auto-Signature\n",
    "\n",
    "This is all well and good. However, notice that the following friction point still exists: when defining the server, we need to define a whole new signature for the API itself. At best, this is just some code duplication of the model signature itself (which is registered in MLflow). At worst, this can result in inconsistent APIs across all models that your team or organization owns, which can cause confusion and frustration in your downstream dependencies.\n",
    "\n",
    "In this particular case, it means that `MultiplexedModelDeployment` is secretly actually **tightly coupled** to the use-case for `translation_model`. What if we wanted to deploy another set of models that don't have to do with language translation? The defined `/serve` API, which returns a JSON object that looks like `{\"translated_text\": \"foo\"}`, would no longer make sense.\n",
    "\n",
    "To address this issue, **what if the API signature for `MultiplexedModelDeployment` could automatically mirror the signature of the underlying models it's serving**?\n",
    "\n",
    "Thankfully, with MLflow Model Registry metadata and some Python dynamic-class-creation shenanigans, this is entirely possible.\n",
    "\n",
    "Let's set things up so that the model server signature is inferred from the registered model itself. Since different versions of an MLflow can have different signatures, we'll use the \"default version\" to \"pin\" the signature; any attempt to multiplex an incompatible-signature model version we will have throw an error.\n",
    "\n",
    "Since Ray Serve binds the request and response signatures at class-definition time, we will use a Python metaclass to set this as a function of the specified model name and default model version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u9GPbQrnP7OD"
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "import pydantic\n",
    "\n",
    "def schema_to_pydantic(schema: mlflow.types.schema.Schema, *, name: str) -> pydantic.BaseModel:\n",
    "    return pydantic.create_model(\n",
    "        name,\n",
    "        **{\n",
    "            k: (v.type.to_python(), pydantic.Field(required=True))\n",
    "            for k, v in schema.input_dict().items()\n",
    "        }\n",
    "    )\n",
    "\n",
    "def get_req_resp_signatures(model_signature: mlflow.models.ModelSignature) -> tuple[pydantic.BaseModel, pydantic.BaseModel]:\n",
    "    inputs: mlflow.types.schema.Schema = model_signature.inputs\n",
    "    outputs: mlflow.types.schema.Schema = model_signature.outputs\n",
    "\n",
    "    return (schema_to_pydantic(inputs, name=\"InputModel\"), schema_to_pydantic(outputs, name=\"OutputModel\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "PgetOY1LKp6m",
    "outputId": "ada066e3-72b3-42af-c284-41118fcb2e20"
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "\n",
    "from fastapi import FastAPI, Response, status\n",
    "from ray import serve\n",
    "from typing import List\n",
    "\n",
    "def deployment_from_model_name(model_name: str, default_version: str = \"1\"):\n",
    "    app = FastAPI()\n",
    "    model_info = mlflow.models.get_model_info(f\"models:/{model_name}/{default_version}\")\n",
    "    input_datamodel, output_datamodel = get_req_resp_signatures(model_info.signature)\n",
    "\n",
    "    @serve.deployment\n",
    "    @serve.ingress(app)\n",
    "    class DynamicallyDefinedDeployment:\n",
    "\n",
    "        MODEL_NAME: str = model_name\n",
    "        DEFAULT_VERSION: str = default_version\n",
    "\n",
    "        @serve.multiplexed(max_num_models_per_replica=2)\n",
    "        async def get_model(self, model_version: str):\n",
    "            model = mlflow.pyfunc.load_model(f\"models:/{self.MODEL_NAME}/{model_version}\")\n",
    "\n",
    "            if model.metadata.get_model_info().signature != model_info.signature:\n",
    "                raise ValueError(f\"Requested version {model_version} has signature incompatible with that of default version {self.DEFAULT_VERSION}\")\n",
    "            return model\n",
    "\n",
    "        # TODO: Extend this to support batching (lists of inputs and outputs)\n",
    "        @app.post(\"/serve\", response_model=List[output_datamodel])\n",
    "        async def serve(self, model_input: input_datamodel, response: Response):\n",
    "            model_id = serve.get_multiplexed_model_id()\n",
    "            if model_id == \"\":\n",
    "                model_id = self.DEFAULT_VERSION\n",
    "\n",
    "            try:\n",
    "                model = await self.get_model(model_id)\n",
    "            except ValueError:\n",
    "                response.status_code = status.HTTP_409_CONFLICT\n",
    "                return [{\"translation_text\": \"FAILED\"}]\n",
    "\n",
    "            return model.predict(model_input.dict())\n",
    "\n",
    "    return DynamicallyDefinedDeployment\n",
    "\n",
    "deployment = deployment_from_model_name(\"translation_model\", default_version=en_to_fr_version)\n",
    "\n",
    "serve.run(deployment.bind(), blocking=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "x911zDhomWMj",
    "outputId": "7dc78df7-4f06-4871-d45f-37cfb852ffc5",
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'translation_text': \"Le temps est beau aujourd'hui\"}]\n"
     ]
    }
   ],
   "source": [
    "import requests\n",
    "\n",
    "resp = requests.post(\n",
    "    \"http://127.0.0.1:8000/serve/\",\n",
    "    json={\"prompt\": \"The weather is lovely today\"},\n",
    ")\n",
    "\n",
    "assert resp.ok\n",
    "assert resp.status_code == 200\n",
    "\n",
    "print(resp.json())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EX7ff2wg5PjL",
    "outputId": "edf0587a-abf5-4160-a621-f9ac4faee6bf",
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'translation_text': \"Le temps est beau aujourd'hui\"}]\n"
     ]
    }
   ],
   "source": [
    "import requests\n",
    "\n",
    "resp = requests.post(\n",
    "    \"http://127.0.0.1:8000/serve/\",\n",
    "    json={\"prompt\": \"The weather is lovely today\"},\n",
    "    headers={\"serve_multiplexed_model_id\": str(en_to_fr_version)},\n",
    ")\n",
    "\n",
    "assert resp.ok\n",
    "assert resp.status_code == 200\n",
    "\n",
    "print(resp.json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kwkDDzebG_dd"
   },
   "source": [
    "Let's now confirm that the signature-check provision we put in place actually works. For this, let's register this same model with a **slightly** different signature. This should be enough to trigger the failsafe.\n",
    "\n",
    "(Remember when we made the input label configurable at the start of this exercise? This is where that finally comes into play. 😎)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JYydMogXHsOJ",
    "outputId": "d8cd96f0-58d2-462b-8902-d9a65b604dc0"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "with mlflow.start_run():\n",
    "    incompatible_version = str(mlflow.pyfunc.log_model(\n",
    "        \"translation_model\",\n",
    "        registered_model_name=\"translation_model\",\n",
    "        python_model=MyTranslationModel(),\n",
    "        pip_requirements=[\"transformers\"],\n",
    "        input_example=pd.DataFrame({\n",
    "            \"text_to_translate\": [\n",
    "                \"Hello my name is Jon.\",\n",
    "            ],\n",
    "        }),\n",
    "        model_config={\n",
    "            \"input_label\": \"text_to_translate\",\n",
    "            \"hfhub_name\": \"google-t5/t5-base\",\n",
    "            \"lang_from\": \"en\",\n",
    "            \"lang_to\": \"de\",\n",
    "        },\n",
    "    ).registered_model_version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5Yn-5VlIH6gs",
    "outputId": "e22f1791-b013-445c-a2ab-08916c5c1032"
   },
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "resp = requests.post(\n",
    "    \"http://127.0.0.1:8000/serve/\",\n",
    "    json={\"prompt\": \"The weather is lovely today\"},\n",
    "    headers={\"serve_multiplexed_model_id\": incompatible_version},\n",
    ")\n",
    "assert not resp.ok\n",
    "resp.status_code == 409\n",
    "\n",
    "assert resp.json()[0][\"translation_text\"] == \"FAILED\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DMhjLZh-jCVa"
   },
   "source": [
    "(The technically \"correct\" thing to do here would be to implement a response container that allows for an \"error message\" to be defined as part of the actual response, rather than \"abusing\" the `translation_text` field like we do here. For demonstration purposes, however, this'll do.)"
   ]
  },
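  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, one possible shape for such a response container is sketched below. It is an assumption-laden illustration rather than part of the deployment above: `make_envelope`, `ResponseEnvelope`, and the `predictions` and `error` field names are all hypothetical, and only `pydantic.create_model` is taken from the approach already used in this notebook.\n",
    "\n",
    "```python\n",
    "from typing import List, Optional\n",
    "\n",
    "import pydantic\n",
    "\n",
    "\n",
    "def make_envelope(output_model, *, name=\"ResponseEnvelope\"):\n",
    "    # Wrap a model-specific output type in a container with a separate error field.\n",
    "    return pydantic.create_model(\n",
    "        name,\n",
    "        predictions=(List[output_model], []),\n",
    "        error=(Optional[str], None),\n",
    "    )\n",
    "\n",
    "\n",
    "# Hypothetical output model mirroring the translation signature.\n",
    "TranslationOutput = pydantic.create_model(\"TranslationOutput\", translation_text=(str, ...))\n",
    "Envelope = make_envelope(TranslationOutput)\n",
    "\n",
    "# A successful response carries predictions and no error message.\n",
    "ok = Envelope(predictions=[{\"translation_text\": \"Hallo\"}])\n",
    "# A failed response carries an error message and no predictions.\n",
    "failed = Envelope(error=\"Requested version has an incompatible signature\")\n",
    "\n",
    "print(ok.error, failed.predictions)\n",
    "```\n",
    "\n",
    "With a container like this, the error path would no longer need to know anything about the fields of the underlying model's output."
   ]
  },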
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cCLtQCgsjwPM"
   },
   "source": [
    "To fully close things out, let's try registering an entirely different model -- with an entirely different signature -- and deploying that via `deployment_from_model_name()`. This will help us confirm that the entire signature is defined from the loaded model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fXUPRszjIGYN"
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "from transformers import pipeline\n",
    "\n",
    "class QuestionAnswererModel(mlflow.pyfunc.PythonModel):\n",
    "    def load_context(self, context):\n",
    "\n",
    "        self.model_context = context.model_config.get(\n",
    "            \"model_context\",\n",
    "            \"My name is Hans and I live in Germany.\",\n",
    "        )\n",
    "        self.model_name = context.model_config.get(\n",
    "            \"model_name\",\n",
    "            \"deepset/roberta-base-squad2\",\n",
    "        )\n",
    "\n",
    "        self.tokenizer_name = context.model_config.get(\n",
    "            \"tokenizer_name\",\n",
    "            \"deepset/roberta-base-squad2\",\n",
    "        )\n",
    "\n",
    "        self.pipeline = pipeline(\n",
    "            \"question-answering\",\n",
    "            model=self.model_name,\n",
    "            tokenizer=self.tokenizer_name,\n",
    "        )\n",
    "\n",
    "    def predict(self, context, model_input, params=None):\n",
    "        resp = self.pipeline(\n",
    "            question=model_input[\"question\"].tolist(),\n",
    "            context=self.model_context,\n",
    "        )\n",
    "\n",
    "        return [resp] if type(resp) is not list else resp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "_p4FrmmhPAuq",
    "outputId": "d5293b38-e56b-4b3f-c4e1-9906ba9c4383"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "with mlflow.start_run():\n",
    "    model_info = mlflow.pyfunc.log_model(\n",
    "        \"question_answerer\",\n",
    "        registered_model_name=\"question_answerer\",\n",
    "        python_model=QuestionAnswererModel(),\n",
    "        pip_requirements=[\"transformers\"],\n",
    "        input_example=pd.DataFrame({\n",
    "            \"question\": [\n",
    "                \"Where do you live?\",\n",
    "                \"What is your name?\",\n",
    "            ],\n",
    "        }),\n",
    "        model_config={\n",
    "            \"model_context\": \"My name is Hans and I live in Germany.\",\n",
    "        },\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "g-0mQytrKyOc",
    "outputId": "dd59ef90-ed96-490a-c27f-8f5dbc023ed3",
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs: \n",
      "  ['question': string (required)]\n",
      "outputs: \n",
      "  ['score': double (required), 'start': long (required), 'end': long (required), 'answer': string (required)]\n",
      "params: \n",
      "  None\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(model_info.signature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "afpSjdgYPaCw",
    "outputId": "b01dcf25-289c-4ed6-f878-172966e88438"
   },
   "outputs": [],
   "source": [
    "from ray import serve\n",
    "\n",
    "serve.run(\n",
    "    deployment_from_model_name(\n",
    "        \"question_answerer\",\n",
    "        default_version=str(model_info.registered_model_version),\n",
    "    ).bind(),\n",
    "    blocking=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "MsLq5vbsS84T",
    "outputId": "73489ce0-984b-4915-e8e0-27db7a8966ec",
    "tags": [
     "keep_output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'score': 3.255764386267401e-05, 'start': 30, 'end': 38, 'answer': 'Germany.'}]\n"
     ]
    }
   ],
   "source": [
    "import requests\n",
    "\n",
    "resp = requests.post(\n",
    "    \"http://127.0.0.1:8000/serve/\",\n",
    "    json={\"question\": \"The weather is lovely today\"},\n",
    ")\n",
    "print(resp.json())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "In this notebook, we've leveraged MLflow's built-in support for tracking model signatures to heavily streamline the process of deploying an HTTP server to serve that model in online fashion. We've taken Ray Serve's powerful-but-fiddly primitives to empower ourselves to, in one line, deploy a model server with:\n",
    "\n",
    "- Version multiplexing;\n",
    "- Automatic REST API signature setup;\n",
    "- Safeguards to prevent use of model versions with incompatible signatures.\n",
    "\n",
    "In doing so, we've demonstrated Ray Serve's value and potential as a toolkit upon which you and your team can [\"build your own ML platform\"](https://docs.ray.io/en/latest/serve/index.html#how-does-serve-compare-to).\n",
    "\n",
    "We've also demonstrated ways to reduce the integration overhead and toil associated with using multiple tools in combination with each other. Seamless integration is a powerful argument in favor of self-contained all-encompassing platforms such as AWS Sagemaker or GCP Vertex AI. We've demonstrated that, with a little clever engineering and principled eye towards the friction points that users -- in this case, MLEs -- care about, we can reap similar benefits without tethering ourselves and our team to expensive vendor contracts.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "- The generated API signature is **very similar** to the model signature, but there's still some mismatch. Can you identify where it is? Try fixing it. Hint: What happens when you try passing in multiple questions to the question-answerer endpoint we set up?\n",
    "- MLflow model signatures allow for [optional inputs](https://mlflow.org/docs/latest/model/signatures.html#required-vs-optional-input-fields). Our current implementation does not account for this. How might we extend the implementation here to support optional inputs?\n",
    "- Similarly, MLflow model signatures allow for non-input [\"inference parameters\"](https://mlflow.org/docs/latest/model/signatures.html#model-signatures-with-inference-params), which our current implementation also does not support. How might we extend our implementation here to support inference parameters?\n",
    "- We use the name `DynamicallyDefinedDeployment` every single time we generate a new deployment, regardless of what model name and version we pass in. Is this a problem? If so, what kind of issues do you foresee this approach creating? Try tweaking `deployment_from_model_name()` to handle those issues."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyPKW7x903JxiHL2pqDZChKh",
   "include_colab_link": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
